import torch

from norse.torch.functional.lif_adex import (
    LIFAdExState,
    LIFAdExFeedForwardState,
    LIFAdExParameters,
    lif_adex_step,
    lif_adex_feed_forward_step,
)

from norse.torch.module.snn import SNN, SNNCell, SNNRecurrent, SNNRecurrentCell


class LIFAdExCell(SNNCell):
    r"""Computes a single euler-integration step of a feed-forward adaptive exponential
    LIF neuron-model *without* recurrence, adapted from
    http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model.
    It takes as input the input current as generated by an arbitrary torch
    module or function. More specifically it implements one integration step
    of the following ODE

    .. math::
        \begin{align*}
            \dot{v} &= 1/\tau_{\text{mem}} \left(v_{\text{leak}} - v + i + \Delta_T exp\left({{v - v_{\text{th}}} \over {\Delta_T}}\right)\right) \\
            \dot{i} &= -1/\tau_{\text{syn}} i \\
            \dot{a} &= 1/\tau_{\text{ada}} \left( a_{current} (v - v_{\text{leak}}) - a \right)
        \end{align*}

    together with the jump condition

    .. math::
        z = \Theta(v - v_{\text{th}})

    and transition equations

    .. math::
        i = i + i_{\text{in}}

    where :math:`i_{\text{in}}` is meant to be the result of applying
    an arbitrary pytorch module (such as a convolution) to input spikes.

    Parameters:
        p (LIFAdExParameters): Parameters of the LIFAdEx neuron model.
        dt (float): Time step to use.

    Examples:
        >>> batch_size = 16
        >>> lif_ex = LIFAdExCell()
        >>> data = torch.randn(batch_size, 20, 30)
        >>> output, s0 = lif_ex(data)
    """

    def __init__(self, p: LIFAdExParameters = LIFAdExParameters(), **kwargs):
        # The base class receives the feed-forward step function and this
        # class's ``initial_state`` as the fallback for a missing state.
        super().__init__(
            lif_adex_feed_forward_step,
            self.initial_state,
            p=p,
            **kwargs,
        )

    def initial_state(self, x: torch.Tensor) -> LIFAdExFeedForwardState:
        """Build a resting state matching the input ``x``.

        Parameters:
            x (torch.Tensor): Input tensor whose shape, device and dtype are
                used for every state tensor.

        Returns:
            LIFAdExFeedForwardState: ``v`` filled with ``p.v_leak`` (and marked
            as requiring grad), ``i`` and ``a`` filled with zeros, all of
            shape ``x.shape``.
        """
        # Materialise v_leak as an independent tensor with the input's shape,
        # device and dtype, like the other modules in this file do. A bare
        # ``p.v_leak.detach()`` would share storage with the parameter and have
        # a different shape from ``i`` and ``a``. ``expand`` keeps support for
        # a non-scalar v_leak that broadcasts against ``x``; ``clone`` breaks
        # the aliasing and yields a contiguous leaf tensor.
        v_leak = self.p.v_leak.detach().to(device=x.device, dtype=x.dtype)
        state = LIFAdExFeedForwardState(
            v=v_leak.expand(x.shape).clone(),
            i=torch.zeros(
                *x.shape,
                device=x.device,
                dtype=x.dtype,
            ),
            a=torch.zeros(
                *x.shape,
                device=x.device,
                dtype=x.dtype,
            ),
        )
        state.v.requires_grad = True
        return state


class LIFAdExRecurrentCell(SNNRecurrentCell):
    r"""Computes a single euler-integration step of a recurrent adaptive exponential
    LIF neuron-model *with* recurrence, adapted from
    http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model.
    More specifically it implements one integration step of the following ODE

    .. math::
        \begin{align*}
            \dot{v} &= 1/\tau_{\text{mem}} \left(v_{\text{leak}} - v + i + \Delta_T exp\left({{v - v_{\text{th}}} \over {\Delta_T}}\right)\right) \\
            \dot{i} &= -1/\tau_{\text{syn}} i \\
            \dot{a} &= 1/\tau_{\text{ada}} \left( a_{current} (v - v_{\text{leak}}) - a \right)
        \end{align*}

    together with the jump condition

    .. math::
        z = \Theta(v - v_{\text{th}})

    and transition equations

    .. math::
        \begin{align*}
            v &= (1-z) v + z v_{\text{reset}} \\
            i &= i + w_{\text{input}} z_{\text{in}} \\
            i &= i + w_{\text{rec}} z_{\text{rec}}
        \end{align*}

    where :math:`z_{\text{rec}}` and :math:`z_{\text{in}}` are the
    recurrent and input spikes respectively.

    Examples:
        >>> batch_size = 16
        >>> lif = LIFAdExRecurrentCell(10, 20)
        >>> input = torch.randn(batch_size, 10)
        >>> output, s0 = lif(input)

    Parameters:
        input_size (int): Size of the input. Also known as the number of input features.
        hidden_size (int): Size of the hidden state, i.e. the number of hidden neurons.
        p (LIFAdExParameters): Parameters of the LIFAdEx neuron model.
        input_weights (torch.Tensor): Weights used for input tensors. Defaults to a random
            matrix normalized to the number of hidden neurons.
        recurrent_weights (torch.Tensor): Weights used for the recurrent spikes. Defaults
            to a random matrix normalized to the number of hidden neurons.
        autapses (bool): Allow self-connections in the recurrence? Defaults to False. Will also
            remove autapses in custom recurrent weights, if set above.
        dt (float): Time step to use.

    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        p: LIFAdExParameters = LIFAdExParameters(),
        **kwargs,
    ):
        super().__init__(
            activation=lif_adex_step,
            state_fallback=self.initial_state,
            p=p,
            input_size=input_size,
            hidden_size=hidden_size,
            **kwargs,
        )

    def initial_state(self, input_tensor: torch.Tensor) -> LIFAdExState:
        """Build a resting state for a single time step of ``input_tensor``.

        Parameters:
            input_tensor (torch.Tensor): Input whose leading dimensions (all but
                the last) plus ``hidden_size`` define the state shape; its
                device and dtype are reused.

        Returns:
            LIFAdExState: ``z``, ``i`` and ``a`` filled with zeros and ``v``
            filled with ``p.v_leak`` (marked as requiring grad).
        """
        # Replace the feature dimension of the input with the hidden size.
        dims = (*input_tensor.shape[:-1], self.hidden_size)
        factory = {"device": input_tensor.device, "dtype": input_tensor.dtype}
        state = LIFAdExState(
            z=torch.zeros(dims, **factory),
            v=torch.full(dims, self.p.v_leak.detach(), **factory),
            i=torch.zeros(dims, **factory),
            a=torch.zeros(dims, **factory),
        )
        state.v.requires_grad = True
        return state


class LIFAdEx(SNN):
    r"""A neuron layer that wraps a feed-forward LIFAdExCell in time such
    that the layer keeps track of temporal sequences of spikes.
    After application, the layer returns a tuple containing
      (spikes from all timesteps, state from the last timestep).

    Example:
        >>> data = torch.zeros(10, 5, 2) # 10 timesteps, 5 batches, 2 neurons
        >>> l = LIFAdEx()
        >>> l(data) # Returns tuple of (Tensor(10, 5, 2), LIFAdExFeedForwardState)

    Parameters:
        p (LIFAdExParameters): The neuron parameters as a torch Module, which allows the module
            to configure neuron parameters as optimizable.
        dt (float): Time step to use in integration. Defaults to 0.001.
    """

    def __init__(self, p: LIFAdExParameters = LIFAdExParameters(), **kwargs):
        super().__init__(
            activation=lif_adex_feed_forward_step,
            state_fallback=self.initial_state,
            p=p,
            **kwargs,
        )

    def initial_state(self, input_tensor: torch.Tensor) -> LIFAdExFeedForwardState:
        """Build a resting state for one time step of ``input_tensor``.

        Parameters:
            input_tensor (torch.Tensor): Input sequence; its first dimension is
                assumed to be time and is dropped from the state shape.

        Returns:
            LIFAdExFeedForwardState: ``v`` filled with ``p.v_leak`` (marked as
            requiring grad), ``i`` and ``a`` filled with zeros.
        """
        dims = input_tensor.shape[1:]  # Assume first dimension is time
        factory = {"device": input_tensor.device, "dtype": input_tensor.dtype}
        state = LIFAdExFeedForwardState(
            v=torch.full(dims, self.p.v_leak.detach(), **factory),
            i=torch.zeros(dims, **factory),
            a=torch.zeros(dims, **factory),
        )
        state.v.requires_grad = True
        return state


class LIFAdExRecurrent(SNNRecurrent):
    r"""A neuron layer that wraps a recurrent LIFAdExRecurrentCell in time (*with*
    recurrence) such that the layer keeps track of temporal sequences of spikes.
    After application, the layer returns a tuple containing
      (spikes from all timesteps, state from the last timestep).

    Example:
        >>> data = torch.zeros(10, 5, 2) # 10 timesteps, 5 batches, 2 neurons
        >>> l = LIFAdExRecurrent(2, 4)
        >>> l(data) # Returns tuple of (Tensor(10, 5, 4), LIFAdExState)

    Parameters:
        input_size (int): The number of input neurons
        hidden_size (int): The number of hidden neurons
        p (LIFAdExParameters): The neuron parameters as a torch Module, which allows the module
            to configure neuron parameters as optimizable.
        input_weights (torch.Tensor): Weights used for input tensors. Defaults to a random
            matrix normalized to the number of hidden neurons.
        recurrent_weights (torch.Tensor): Weights used for the recurrent spikes. Defaults
            to a random matrix normalized to the number of hidden neurons.
        autapses (bool): Allow self-connections in the recurrence? Defaults to False. Will also
            remove autapses in custom recurrent weights, if set above.
        dt (float): Time step to use in integration. Defaults to 0.001.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        p: LIFAdExParameters = LIFAdExParameters(),
        **kwargs,
    ):
        super().__init__(
            activation=lif_adex_step,
            state_fallback=self.initial_state,
            input_size=input_size,
            hidden_size=hidden_size,
            p=p,
            **kwargs,
        )

    def initial_state(self, input_tensor: torch.Tensor) -> LIFAdExState:
        """Build a resting state for one time step of ``input_tensor``.

        Parameters:
            input_tensor (torch.Tensor): Input sequence; the first dimension is
                assumed to be time and is dropped, and the last (feature)
                dimension is replaced by ``hidden_size``.

        Returns:
            LIFAdExState: ``z``, ``i`` and ``a`` filled with zeros and ``v``
            filled with ``p.v_leak`` (marked as requiring grad).
        """
        dims = (  # Remove first dimension (time)
            *input_tensor.shape[1:-1],
            self.hidden_size,
        )
        factory = {"device": input_tensor.device, "dtype": input_tensor.dtype}
        state = LIFAdExState(
            z=torch.zeros(dims, **factory),
            v=torch.full(dims, self.p.v_leak.detach(), **factory),
            i=torch.zeros(dims, **factory),
            a=torch.zeros(dims, **factory),
        )
        state.v.requires_grad = True
        return state
